#include <DB/Interpreters/InJoinSubqueriesPreprocessor.h>
#include <DB/Interpreters/Context.h>
#include <DB/Storages/StorageDistributed.h>
#include <DB/Parsers/ASTSelectQuery.h>
#include <DB/Parsers/ASTTablesInSelectQuery.h>
#include <DB/Parsers/ASTFunction.h>
#include <DB/Parsers/ASTIdentifier.h>


namespace DB
{

namespace ErrorCodes
{
	extern const int DISTRIBUTED_IN_JOIN_SUBQUERY_DENIED;
	extern const int LOGICAL_ERROR;
}


namespace
{

/** Walk the AST under `node` and invoke `f` for IN / NOT IN functions and for non-GLOBAL JOINs.
  *
  * `f` receives three arguments:
  *  - for a function named `in` or `notIn`: its second argument, the function node, nullptr;
  *  - for a JOIN element whose locality is not Global and whose table expression holds a subquery:
  *    the subquery node, nullptr, pointer to the ASTTableJoin node.
  *
  * Only first-level subqueries are visited: the walk stops at a matched IN function or JOIN element,
  *  and never enters a child that is an ASTSelectQuery.
  */
template <typename F>
void forEachNonGlobalSubquery(IAST * node, F && f)
{
	if (auto * func = typeid_cast<ASTFunction *>(node))
	{
		const bool is_in_operator = (func->name == "in" || func->name == "notIn");
		if (is_in_operator)
		{
			/// The right-hand side of IN is the second argument of the function.
			auto * right_hand_side = func->arguments->children.at(1).get();
			f(right_hand_side, func, nullptr);
			return;
		}

		/// Any other function is traversed below, because a subquery may sit inside an aggregate or lambda function.
	}

	if (auto * element = typeid_cast<ASTTablesInSelectQueryElement *>(node))
	{
		if (element->table_join && element->table_expression)
		{
			auto & table_join = static_cast<ASTTableJoin &>(*element->table_join);

			/// A JOIN that is already GLOBAL is left alone.
			if (table_join.locality == ASTTableJoin::Locality::Global)
				return;

			auto & table_expression = static_cast<ASTTableExpression &>(*element->table_expression);
			if (table_expression.subquery)
				f(table_expression.subquery.get(), nullptr, &table_join);

			return;
		}

		/// Other kinds of elements are traversed below, because a subquery may sit inside ARRAY JOIN.
	}

	/// Descend into all children except nested SELECT queries (e.g. scalar subqueries), which are irrelevant here.
	for (auto & child : node->children)
	{
		if (typeid_cast<ASTSelectQuery *>(child.get()))
			continue;

		forEachNonGlobalSubquery(child.get(), f);
	}
}


/** Invoke `f` for the table name of every table expression found under `node`, at any nesting level.
  * `f` receives a mutable reference to the `database_and_table_name` member, so it may replace the node.
  * Table expressions without a table name are skipped.
  */
template <typename F>
void forEachTable(IAST * node, F && f)
{
	if (auto * expression = typeid_cast<ASTTableExpression *>(node))
		if (expression->database_and_table_name)
			f(expression->database_and_table_name);

	/// Children are visited after the callback has been applied to the current node.
	for (auto & child : node->children)
		forEachTable(child.get(), f);
}


/** Look up the storage named by the identifier node `database_and_table` via `context.tryGetTable`.
  * The node is treated as an ASTIdentifier that has either no children (table name only;
  *  an empty database name is passed to the lookup) or exactly two children (database, table).
  * Any other number of children is reported as LOGICAL_ERROR.
  */
StoragePtr tryGetTable(const ASTPtr & database_and_table, const Context & context)
{
	const auto & identifier = static_cast<const ASTIdentifier &>(*database_and_table);
	const auto & components = identifier.children;

	String database;
	String table;

	switch (components.size())
	{
		case 0:
			/// Unqualified name: database stays empty.
			table = identifier.name;
			break;

		case 2:
			database = static_cast<const ASTIdentifier *>(components[0].get())->name;
			table = static_cast<const ASTIdentifier *>(components[1].get())->name;
			break;

		default:
			throw Exception("Logical error: unexpected number of components in table expression", ErrorCodes::LOGICAL_ERROR);
	}

	return context.tryGetTable(database, table);
}


/** Replace `database_and_table` with a newly created table identifier node.
  *
  * If `database_name` is empty, the new node is a single identifier named `table_name`.
  * Otherwise the new node is named "database_name.table_name" and gets two children:
  *  a Database identifier and a Table identifier for the separate components.
  */
void replaceDatabaseAndTable(ASTPtr & database_and_table, const String & database_name, const String & table_name)
{
	if (database_name.empty())
	{
		database_and_table = std::make_shared<ASTIdentifier>(
			StringRange(), table_name, ASTIdentifier::Table);
		return;
	}

	/// Component nodes are needed only for the qualified form, so they are not allocated for the unqualified one.
	ASTPtr database = std::make_shared<ASTIdentifier>(StringRange(), database_name, ASTIdentifier::Database);
	ASTPtr table = std::make_shared<ASTIdentifier>(StringRange(), table_name, ASTIdentifier::Table);

	ASTPtr qualified = std::make_shared<ASTIdentifier>(
		StringRange(), database_name + "." + table_name, ASTIdentifier::Table);
	qualified->children = {database, table};

	/// Assign last, so the previous node stays alive while the new one is being built.
	database_and_table = qualified;
}

}


/** Apply the `distributed_product_mode` setting to non-GLOBAL IN / JOIN subqueries of `query`.
  *
  * Nothing is done when:
  *  - `query` is null;
  *  - the mode is ALLOW;
  *  - the query has no table, or its table is not an ASTIdentifier;
  *  - the table cannot be resolved, or hasAtLeastTwoShards() is false for it.
  *
  * Otherwise, for every table inside such a subquery for which hasAtLeastTwoShards() is true:
  *  - DENY: throw DISTRIBUTED_IN_JOIN_SUBQUERY_DENIED;
  *  - GLOBAL: rename `in` to `globalIn`, `notIn` to `globalNotIn`, or set the JOIN locality to Global;
  *  - LOCAL: replace the table name with the remote database and table names of the storage.
  * Any other mode value is reported as LOGICAL_ERROR.
  */
void InJoinSubqueriesPreprocessor::process(ASTSelectQuery * query) const
{
	if (!query)
		return;

	const SettingDistributedProductMode mode = context.getSettingsRef().distributed_product_mode;
	if (mode == DistributedProductMode::ALLOW)
		return;

	/// Only an ordinary table (an identifier) in the outer query is of interest.
	ASTPtr main_table = query->table();
	if (!main_table || !typeid_cast<ASTIdentifier *>(main_table.get()))
		return;

	/// The outer table must resolve to a storage with at least two shards.
	StoragePtr main_storage = tryGetTable(main_table, context);
	if (!main_storage || !hasAtLeastTwoShards(*main_storage))
		return;

	/// Handles a single table found inside a subquery.
	/// `in_function` / `join` point to the enclosing IN function or JOIN node; the other one is null.
	auto handle_table = [&] (ASTPtr & database_and_table, IAST * in_function, IAST * join)
	{
		StoragePtr inner_storage = tryGetTable(database_and_table, context);
		if (!inner_storage || !hasAtLeastTwoShards(*inner_storage))
			return;

		if (mode == DistributedProductMode::DENY)
		{
			throw Exception("Double-distributed IN/JOIN subqueries is denied (distributed_product_mode = 'deny')."
				" You may rewrite query to use local tables in subqueries, or use GLOBAL keyword, or set distributed_product_mode to suitable value.",
				ErrorCodes::DISTRIBUTED_IN_JOIN_SUBQUERY_DENIED);
		}

		if (mode == DistributedProductMode::GLOBAL)
		{
			if (in_function)
			{
				auto & function_name = static_cast<ASTFunction *>(in_function)->name;

				/// The same function may be seen several times (once per table in its subquery),
				///  so the already converted names are accepted silently.
				if (function_name == "in")
					function_name = "globalIn";
				else if (function_name == "notIn")
					function_name = "globalNotIn";
				else if (function_name != "globalIn" && function_name != "globalNotIn")
					throw Exception("Logical error: unexpected function name " + function_name, ErrorCodes::LOGICAL_ERROR);
			}
			else if (join)
				static_cast<ASTTableJoin &>(*join).locality = ASTTableJoin::Locality::Global;
			else
				throw Exception("Logical error: unexpected AST node", ErrorCodes::LOGICAL_ERROR);

			return;
		}

		if (mode == DistributedProductMode::LOCAL)
		{
			/// Convert distributed table to corresponding remote table.
			const std::pair<std::string, std::string> remote_names = getRemoteDatabaseAndTableName(*inner_storage);
			replaceDatabaseAndTable(database_and_table, remote_names.first, remote_names.second);
			return;
		}

		throw Exception("InJoinSubqueriesPreprocessor: unexpected value of 'distributed_product_mode' setting", ErrorCodes::LOGICAL_ERROR);
	};

	forEachNonGlobalSubquery(query, [&] (IAST * subquery, IAST * function, IAST * table_join)
	{
		forEachTable(subquery, [&] (ASTPtr & database_and_table)
		{
			handle_table(database_and_table, function, table_join);
		});
	});
}


/** Return true only if `table` reports isRemote(), is exactly a StorageDistributed,
  *  and its getShardCount() is at least 2.
  */
bool InJoinSubqueriesPreprocessor::hasAtLeastTwoShards(const IStorage & table) const
{
	/// The type check is attempted only for storages that report themselves as remote.
	const StorageDistributed * distributed = table.isRemote()
		? typeid_cast<const StorageDistributed *>(&table)
		: nullptr;

	return distributed && distributed->getShardCount() >= 2;
}


std::pair<std::string, std::string>
InJoinSubqueriesPreprocessor::getRemoteDatabaseAndTableName(const IStorage & table) const
{
	const StorageDistributed & distributed = typeid_cast<const StorageDistributed &>(table);
	return { distributed.getRemoteDatabaseName(), distributed.getRemoteTableName() };
}


}
